/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/spmm.cu
 * \brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {

/*!
 * \brief Determine whether the cuSPARSE SpMM routine is applicable.
 * \param more_nnz_than_matrix_size True if the CSR matrix has more non-zeros
 *        than num_rows * num_cols entries; only consulted on CUDA >= 11.0.
 */
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
  // Legacy cuSPARSE is used only with 32-bit indices and fp32/fp64 data.
  return std::is_same<IdType, int>::value && bits > 16;
#else
  if (bits == 16)
    return false;  // cuSPARSE's SpMM on fp16 is slow, temporarily disabled.
  // If the CSR matrix has more NNZ than its dense size (num_rows * num_cols),
  // avoid cuSPARSE 11.1's routine and fall back to DGL's own kernel.
  return !more_nnz_than_matrix_size;
#endif
}
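
// For illustration, these values follow directly from the predicate above
// when compiled against CUDA >= 11.0:
//   cusparse_available<32, int32_t>(false) -> true   (fp32, sparse enough)
//   cusparse_available<16, int32_t>(false) -> false  (fp16 path disabled)
//   cusparse_available<32, int32_t>(true)  -> false  (more NNZ than dense size)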

/*!
 * \brief CUDA implementation of g-SpMM on CSR format for heterogeneous graphs.
 * \note Uses cuSPARSE if the reduce operator is `sum` and there is
 *       no broadcast; uses DGL's own kernel in all other cases.
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
             const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
             const std::vector<dgl_type_t>& out_ntids) {  // output node type id
  // Edge features are "scalar" when there is exactly one value per nonzero,
  // in which case they can serve directly as the CSR value array.
  bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  SWITCH_BITS(bits, DType, {
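    // Per-node-type scratch buffers that hold cuSPARSE output on the legacy
    // (CUDA < 11.0) path until it is transposed into the user tensors.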
    std::vector<DType*> trans_out((*vec_out).size(), nullptr);

    bool use_legacy_cusparsemm =
        (CUDART_VERSION < 11000) && (reduce == "sum") &&
        // legacy cuSPARSE does not care about NNZ, hence the argument "false".
        ((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
         (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
    // Create temporary output buffers; the legacy cuSPARSE path writes its
    // results transposed, and they are fixed up after the main loop.
    if (use_legacy_cusparsemm) {
      for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
        const int m = (*vec_out)[ntype]->shape[0];
        const int n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
          m * n * sizeof(DType)));
        CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
        trans_out[ntype] = out;
      }
    }
    // Check shape of ufeat for all relation type and compute feature size
    int64_t x_length = 1;
    for (dgl_type_t etype = 0; etype + 1 < ufeat_ntids.size(); ++etype) {
      NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
      NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
      CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
      for (int i = 1; i < ufeat->ndim; ++i) {
        if (ufeat->shape[i] != next_ufeat->shape[i]) {
          if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
            LOG(FATAL) <<
              "Homogenized message passing on heterogeneous graphs does not support " <<
              "automatic broadcasting.  Please manually broadcast it before calling " <<
              "message passing functions.";
          else
            LOG(FATAL) << "Input features have different shapes.";
          return;
        }

        if (etype == 0)
          x_length *= ufeat->shape[i];
      }
    }
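    // x_length is now the flattened per-node feature length shared by all
    // relation types.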
    // TODO(Israt): Can python do the following initializations while creating the tensors?
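    // max/min need their outputs pre-filled with the reduction identity (the
    // functor's zero()) and the arg-type buffers filled with -1 to mark
    // destination rows that receive no message.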
    if (reduce == "max" ||  reduce == "min") {
      const int64_t dim = bcast.out_len;
      std::vector<bool> updated((*vec_out).size(), false);
      for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
        DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
        if (reduce == "max")
          _Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
        else  // min
          _Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
        const dgl_type_t dst_id = out_ntids[etype];
        if (!updated[dst_id]) {
          updated[dst_id] = true;
          if (op == "copy_lhs") {
            IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
            _Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
          }
          if (op == "copy_rhs") {
            IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
            _Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
          }
        }
      }
    }

    auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
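    // One SpMM launch per relation (edge type): pick this relation's source
    // and destination node types, then dispatch to cuSPARSE when profitable
    // or to DGL's generic kernel otherwise.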
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      const dgl_type_t src_id = ufeat_ntids[etype];
      const dgl_type_t dst_id = out_ntids[etype];
      CSRMatrix csr = vec_csr[etype];
      if (reduce == "sum") {
        bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
        /* Call SpMM for each relation type */
        if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
          /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
          DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
            static_cast<DType*>((*vec_out)[dst_id]->data);
          CusparseCsrmm2Hetero<DType, IdType>(
              csr.indptr->ctx, csr,
              static_cast<DType*>(vec_ufeat[src_id]->data),
              nullptr,
              out,
              x_length, thr_entry->stream);
        } else if (op == "mul" && is_scalar_efeat &&
            cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
          NDArray efeat = vec_efeat[etype];
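          // A non-null csr.data maps CSR positions to original edge IDs, so
          // gather efeat into CSR order to keep values aligned with csr.indices.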
          if (!IsNullArray(csr.data))
            efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
          CusparseCsrmm2Hetero<DType, IdType>(
              csr.indptr->ctx, csr,
              static_cast<DType*>(vec_ufeat[src_id]->data),
              static_cast<DType*>(efeat->data),
              // TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
              static_cast<DType*>((*vec_out)[dst_id]->data),
              x_length, thr_entry->stream);
        } else {  // general kernel
          NDArray ufeat = (vec_ufeat.size() == 0) ?
            NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
            NullArray() : vec_efeat[etype];
          SWITCH_OP(op, Op, {
            cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
                bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
          });
        }
      } else if (reduce == "max") {
          SWITCH_OP(op, Op, {
            NDArray ufeat = (vec_ufeat.size() == 0) ?
                NullArray() : vec_ufeat[src_id];
            NDArray efeat = (vec_efeat.size() == 0) ?
                NullArray() : vec_efeat[etype];
            cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
                bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
                (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
                src_id, etype);
          });
      } else if (reduce == "min") {
          SWITCH_OP(op, Op, {
            NDArray ufeat = (vec_ufeat.size() == 0) ?
                NullArray() : vec_ufeat[src_id];
            NDArray efeat = (vec_efeat.size() == 0) ?
                NullArray() : vec_efeat[etype];
            cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
                bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
                (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
                src_id, etype);
        });
      } else {
        LOG(FATAL) << "Not implemented";
      }
    }

    if (use_legacy_cusparsemm) {
      // The legacy path stored each result transposed; transpose it back into
      // the user-provided output tensor.
      for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
        const int m = (*vec_out)[ntype]->shape[0];
        const int n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
        _Transpose(trans_out[ntype], C_data, n, m);
        device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
      }
    }
  });
}

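// Explicit instantiations cover IdType in {int32_t, int64_t} crossed with
// 16-, 32-, and 64-bit floating-point data.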
template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);



}  // namespace aten
}  // namespace dgl
